import numpy as np
import cv2
import matplotlib.pyplot as plt
import math


def dctTransform(img, threshold=5000):
	"""Display an image, its DCT log-magnitude and two inverse-DCT reconstructions.

	The 2-D DCT is computed with ``cv2.dct``. The top-left 20x20 block of
	absolute coefficients is printed. A copy of the coefficients in which every
	entry with magnitude >= ``threshold`` is set to zero is inverse-transformed
	to show the effect of removing the strongest coefficients. Four panels are
	shown with matplotlib:
	- the input;
	- the DCT log-magnitude;
	- the plain IDCT;
	- the thresholded IDCT.

	Args:
		img: single-channel (2-D) image array, e.g. from
			``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``.
		threshold: coefficients with ``|c| >= threshold`` are zeroed before the
			second reconstruction. Defaults to 5000.

	Raises:
		ValueError: if ``img`` is None (e.g. a failed ``cv2.imread``), empty,
			or not 2-D.
	"""
	if img is None:
		raise ValueError("img is None; check that the image file was read successfully")
	img = np.asarray(img, dtype=np.float64)
	if img.ndim != 2 or img.size == 0:
		raise ValueError("dctTransform expects a non-empty 2-D image, got shape %r" % (img.shape,))

	rows, cols = img.shape
	# cv2.dct only supports even-sized arrays: pad odd dimensions with one
	# replicated edge row/column, then crop the reconstructions back below.
	padded = np.pad(img, ((0, rows % 2), (0, cols % 2)), mode='edge')

	img_dct = cv2.dct(padded)
	# log1p keeps zero coefficients finite (log(0) would be -inf, emit a
	# RuntimeWarning and break imshow's colour scaling).
	magnitude = 20 * np.log1p(np.abs(img_dct))

	print(np.abs(img_dct)[0:20, 0:20])

	img_dct_deal = img_dct.copy()
	# Threshold on magnitude: a large negative coefficient carries as much
	# energy as a large positive one.
	img_dct_deal[np.abs(img_dct_deal) >= threshold] = 0

	image_idct = cv2.idct(img_dct)[:rows, :cols]
	image_deal_idct = cv2.idct(img_dct_deal)[:rows, :cols]

	panels = (
		(img, 'Input Image', 'gray'),
		(magnitude, 'DCT log-magnitude', None),
		(image_idct, 'IDCT reconstruction', 'gray'),
		(image_deal_idct, 'IDCT, |coef| >= %g removed' % threshold, 'gray'),
	)
	for index, (data, title, cmap) in enumerate(panels, start=1):
		plt.subplot(2, 2, index)
		plt.imshow(data, cmap=cmap)
		plt.title(title)
		plt.xticks([])
		plt.yticks([])
	plt.show()


if __name__ == '__main__':
	image_path = "../images/airplane.png"
	img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
	# cv2.imread returns None instead of raising when the file cannot be read,
	# so fail here with a clear message rather than deep inside dctTransform.
	if img is None:
		raise SystemExit("Could not read image: %s" % image_path)
	dctTransform(img)